library(spmirt)
library(ggplot2)
library(datasim)
# Number of observations requested from the simulator.
n <- 8000

# Two-part model specification handed to sim_model():
#   * prob: constant 0.5 plus a gp() term over the coordinates (s1, s2) with
#     an "exp_cor" correlation model, phi = 0.02 and sigma2 = 0.5;
#   * size: constant 1.
# NOTE(review): the exact meaning of gp() and its arguments is defined by the
# simulation package and is not visible in this script -- confirm there.
f <- list(
  prob ~ I(0.5) + gp(
    list(s1, s2),
    cor.model = "exp_cor",
    cor.params = list(phi = 0.02),
    sigma2 = 0.5
  ),
  size ~ I(1)
)

# Draw the data set with rbinom as generator. pnorm and identity are passed as
# the inverse links (presumably one per formula, in order -- confirm).
# No seed is supplied (a `seed = 2` argument used to sit here, commented out),
# so the simulated data is not pinned by this call.
data <- sim_model(
  n = n,
  formula = f,
  generator = rbinom,
  link_inv = list(pnorm, identity)
)

# Give the simulated gp.list.prob column the shorter name `gp`.
data <- dplyr::rename(data, gp = gp.list.prob)

# Scatter of the locations coloured by the `gp` column.
ggplot(data, aes(x = s1, y = s2, colour = gp)) +
  geom_point() +
  scale_colour_distiller(palette = "RdYlBu")

# Same locations, coloured by the response treated as a factor.
ggplot(data, aes(x = s1, y = s2, colour = factor(response))) +
  geom_point(size = 2)

# Empirical variogram of the `gp` column: distance bins of width 0.005 up to
# a cutoff of 0.7.
vg0 <- gstat::variogram(
  gp ~ 1,
  locations = ~ s1 + s2,
  data = data,
  cutoff = 0.7,
  width = 0.005
)

# Semivariance against distance; point size shows the number of pairs (np)
# in each bin, and both axes are forced to include zero.
ggplot(vg0, aes(x = dist, y = gamma, weight = np, size = np)) +
  geom_point() +
  expand_limits(x = 0, y = 0) +
  scale_x_continuous(limits = c(0, 0.7))

# Monte Carlo value of p * (1 - p), where p is the mean of pnorm(X) over
# 10^5 draws X ~ Normal(mean = beta0, variance = sigma2).
# p * (1 - p) is the variance of a Bernoulli variable with success
# probability p.
#
# beta0:  mean of the normal draws.
# sigma2: variance of the normal draws (its square root is used as sd).
# Returns a single numeric value. Uses the RNG, so results vary between calls
# unless a seed is set by the caller.
get_sill <- function(beta0, sigma2) {
  latent <- rnorm(n = 1e5, mean = beta0, sd = sigma2^0.5)
  p_success <- mean(pnorm(latent))
  p_success * (1 - p_success)
}

# Monte Carlo sample variance of pnorm(X) over 10^5 draws
# X ~ Normal(mean = beta0, variance = sigma2).
#
# beta0:  mean of the normal draws.
# sigma2: variance of the normal draws (its square root is used as sd).
# Returns a single numeric value. Uses the RNG, so results vary between calls
# unless a seed is set by the caller.
get_sigma2 <- function(beta0, sigma2) {
  probs <- pnorm(rnorm(n = 1e5, mean = beta0, sd = sigma2^0.5))
  var(probs)
}

# Difference get_sill(beta0, sigma2) - get_sigma2(beta0, sigma2).
# The two helpers each draw their own Monte Carlo sample, so the two terms
# come from independent simulations (get_sill is evaluated first).
#
# beta0, sigma2: forwarded unchanged to both helpers.
# Returns a single numeric value.
get_nugget <- function(beta0, sigma2) {
  get_sill(beta0, sigma2) - get_sigma2(beta0, sigma2)
}

# Empirical variogram of the binary response: distance bins of width 0.005
# up to a cutoff of 0.5.
vg <- gstat::variogram(response ~ 1, ~ s1 + s2, data, cutoff = 0.5, width = 0.005)

# Reference curve built from the Monte Carlo helpers at beta0 = 0.5 and
# sigma2 = 0.5, using an exponential shape with range parameter 0.02
# (the same values used in the simulation formula).
nugget <- get_nugget(0.5, 0.5)
sigma2 <- get_sigma2(0.5, 0.5)
phi_aux <- 0.02
vg$theor <- nugget + sigma2 * (1 - exp(-vg$dist / phi_aux))

# Fit a nugget + exponential variogram model, starting with the sample
# variance split evenly between nugget and partial sill.
# vgm is called through the gstat namespace like the other gstat functions
# here: gstat is never attached with library(), so a bare vgm() is only found
# if some other package happens to attach it.
vg_fit <- gstat::fit.variogram(vg,
  gstat::vgm(0.5 * var(data$response), "Exp",
             nugget = 0.5 * var(data$response)))

# Row 1 of the fitted model is read as the nugget component, row 2 as the
# exponential component (partial sill and range).
nugget_y <- vg_fit$psill[1]
sigma2_y <- vg_fit$psill[2]
phi_y <- vg_fit$range[2]
vg$estim <- nugget_y + sigma2_y * (1 - exp(-vg$dist / phi_y))

# Empirical semivariances (point size = number of pairs) with the reference
# and fitted curves overlaid.
ggplot(vg, aes(dist, gamma)) +
  geom_point(aes(size = np)) +
  geom_line(aes(y = theor, col = "theor")) +
  geom_line(aes(y = estim, col = "estima")) +
  scale_x_continuous(limits = c(0, 0.5)) +
  expand_limits(y = 0, x = 0)

# Log of the pair-count-weighted squared distance between an empirical
# variogram and a nugget + exponential curve implied by the parameters.
#
# param: numeric vector c(beta0, log(sigma2), log(phi)); the last two entries
#        are exponentiated before use.
# vg:    object with components `dist`, `gamma` and `np` (read via `$`).
#
# The curve is built from 10^4 draws X ~ Normal(beta0, sigma2):
#   partial sill = var(pnorm(X)),
#   nugget       = p * (1 - p) - partial sill, with p = mean(pnorm(X)),
#   curve(d)     = nugget + partial sill * (1 - exp(-d / phi)).
# Returns log(sum(np * (curve - gamma)^2)). Uses the RNG, so repeated calls
# with the same arguments generally give slightly different values.
cost_vg_fit <- function(param, vg) {
  latent_var <- exp(param[2])
  range_par <- exp(param[3])
  probs <- pnorm(rnorm(n = 1e4, mean = param[1], sd = latent_var^0.5))
  p_mean <- mean(probs)
  partial_sill <- var(probs)
  nugget_part <- p_mean * (1 - p_mean) - partial_sill
  curve <- nugget_part + partial_sill * (1 - exp(-vg$dist / range_par))
  log(sum(vg$np * (curve - vg$gamma)^2))
}

# Evaluate cost_vg_fit on a (beta0, sigma2) grid with phi held at 0.02.
# sigma2 starts at 0, where log(sigma2) is -Inf and is exponentiated back
# to 0 inside cost_vg_fit.
mesh <- expand.grid(
  beta0 = seq(from = 0, to = 0.7, by = 0.01),
  sigma2 = seq(from = 0, to = 1, by = 0.01)
)
mesh$cost <- purrr::map2_dbl(
  mesh$beta0,
  mesh$sigma2,
  function(b0, s2) cost_vg_fit(c(b0, log(s2), log(0.02)), vg)
)

# Heat map of the cost surface over the grid.
ggplot(mesh, aes(x = beta0, y = sigma2, fill = cost)) +
  geom_raster() +
  scale_fill_distiller(palette = "RdYlBu")

# Log of the squared distance between simulated and target moments.
#
# param:   numeric vector c(beta0, log(sigma2)); the second entry is
#          exponentiated before use.
# tomatch: numeric vector c(target mean, target variance).
#
# Draws 10^5 values X ~ Normal(beta0, sigma2), takes the mean and the sample
# variance of pnorm(X), and returns
#   log((mean - tomatch[1])^2 + (variance - tomatch[2])^2).
# Uses the RNG, so repeated calls with the same arguments generally give
# slightly different values.
cost_moments <- function(param, tomatch) {
  latent_var <- exp(param[2])
  probs <- pnorm(rnorm(n = 1e5, mean = param[1], sd = latent_var^0.5))
  mean_gap <- mean(probs) - tomatch[1]
  var_gap <- var(probs) - tomatch[2]
  log(mean_gap^2 + var_gap^2)
}
# Moment matching: search for (beta0, log(sigma2)) whose simulated mean and
# variance of pnorm(X) match the sample mean of the response and the sample
# variance minus the first fitted psill entry (vg_fit$psill[1]).
par <- optim(
  par = c(0, log(0.5)),
  fn = cost_moments,
  tomatch = c(
    mean(data$response),
    var(data$response) - vg_fit$psill[1]
  )
)

# Estimates on the natural scale: beta0 and sigma2.
c(par$par[1], exp(par$par[2]))
## [1] 0.5138506 0.5181197
# (The line above is output recorded from an earlier run.)

# Cost surface of cost_moments on a (beta0, sigma2) grid.
# NOTE(review): 0.18 is hard-coded here where the optim() call above uses
# vg_fit$psill[1]; presumably it is a rounded nugget value -- confirm.
mesh <- expand.grid(
  beta0 = seq(from = 0.25, to = 0.75, by = 0.01),
  sigma2 = seq(from = 0.25, to = 0.75, by = 0.01)
)
mesh$cost <- purrr::map2_dbl(
  mesh$beta0,
  mesh$sigma2,
  function(b0, s2) {
    cost_moments(
      c(b0, log(s2)),
      c(mean(data$response), var(data$response) - 0.18)
    )
  }
)

# Heat map of the moment-matching cost
# (scale_fill_viridis_c() is an alternative fill scale).
ggplot(mesh, aes(x = beta0, y = sigma2, fill = cost)) +
  geom_raster() +
  scale_fill_distiller(palette = "RdYlBu")